# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505

import operator
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from typing import Any, Callable, List, Tuple, Union

import torch
from torch.utils._pytree import tree_map

# Shorthand for the ATen operator namespace; its overloads (e.g. ``aten.mm.default``)
# are the keys of ``flop_mapping`` below.
aten = torch.ops.aten

class Phase(Enum):
    """Which pass a counted flop belongs to: the forward pass or the backward pass."""

    FWD = 1
    BWD = 2


def _format_flops(flops: float) -> str:
    """Returns a formatted flops string"""
    if flops > 1e12:
        return f'{flops / 1e12:.2f} TFLOPs'
    elif flops > 1e9:
        return f'{flops / 1e9:.2f} GFLOPs'
    elif flops > 1e6:
        return f'{flops / 1e6:.2f} MFLOPs'
    elif flops > 1e3:
        return f'{flops / 1e3:.2f} kFLOPs'
    return f'{flops} FLOPs'


def flop_count(module: Union[torch.nn.Module, Callable], *args, verbose: bool = False, forward_only: bool = True,
               **kwargs) -> Union[int, Tuple[int, int]]:
    """
    Count the number of floating point operations in a model.
    Ideas from https://pastebin.com/AkvAyJBw.

    Every tensor in ``args``/``kwargs`` is wrapped in a ``FlopTensor`` whose
    ``__torch_dispatch__`` looks up each executed ATen op in ``flop_mapping``
    and accumulates its count. Counts are attributed to ``'Global'`` and to the
    direct child module (if any) that is currently executing.

    Parameters
    ----------
    module : Union[torch.nn.Module, Callable]
        The model to count the number of floating point operations.
        A plain callable is wrapped in a thin ``nn.Module``.
    args : Any
        The positional arguments to pass to the model.
    verbose : bool
        Whether to print the per-module breakdown of floating point operations.
    forward_only : bool
        Whether to only count the number of floating point operations in the forward pass.
        When False, a backward pass is run with zero output gradients.
    kwargs : Any
        The keyword arguments to pass to the model.

    Returns
    -------
    Union[int, Tuple[int, int]]
        The forward flop count when ``forward_only`` is True, otherwise a
        ``(forward, backward)`` tuple.
    """
    # Autograd rejects in-place updates on leaf tensors that require grad, so
    # ``wrap`` adds ``+ 0`` to the inputs of (possibly) in-place callables to
    # make them non-leaf.
    maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
                     or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))

    class DummyModule(torch.nn.Module):
        """Adapts a plain callable to ``nn.Module`` so it can be called and hooked uniformly."""

        def __init__(self, func):
            super().__init__()
            self.func = func
            # functools.partial objects and other callables may lack ``__name__``.
            self.__name__ = getattr(func, '__name__', type(func).__name__)

        def forward(self, *args, **kwargs):
            return self.func(*args, **kwargs)

    total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
    flop_counts = defaultdict(lambda: defaultdict(int))
    # Stack of names flops are attributed to; 'Global' is always at the bottom.
    parents = ['Global']
    # ``None`` means "not counting yet": it is active while inputs are wrapped, so
    # the bookkeeping ``+ 0`` for in-place callables is not billed to the model.
    cur_phase = None
    module = module if isinstance(module, torch.nn.Module) else DummyModule(module)

    class FlopTensor(torch.Tensor):
        elem: torch.Tensor

        __slots__ = ['elem']

        @staticmethod
        def __new__(cls, elem):
            # The wrapping tensor (FlopTensor) shouldn't hold any
            # memory for the class in question, but it should still
            # advertise the same device as before
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                strides=elem.stride(), storage_offset=elem.storage_offset(),
                # TODO: clone storage aliasing
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad
            )
            # ...the real tensor is held as an element on the tensor.
            r.elem = elem
            return r

        def __repr__(self):
            if self.grad_fn:
                return f"FlopTensor({self.elem}, grad_fn={self.grad_fn})"
            return f"FlopTensor({self.elem})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}

            def unwrap(e):
                return e.elem if isinstance(e, FlopTensor) else e

            # Run the op on the real tensors; FlopTensor only observes it.
            rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            outs = normalize_tuple(rs)

            # Ops missing from ``flop_mapping`` are free.
            if cur_phase is not None and func in flop_mapping:
                op_flops = flop_mapping[func](args, outs)
                for par in parents:
                    flop_counts[par][func.__name__] += op_flops
                total_flop_count[cur_phase] += op_flops

            def wrap_out(e):
                return FlopTensor(e) if isinstance(e, torch.Tensor) else e

            return tree_map(wrap_out, rs)

    def is_autogradable(x):
        return isinstance(x, torch.Tensor) and x.is_floating_point()

    def normalize_tuple(x):
        if not isinstance(x, tuple):
            return (x,)
        return x

    def create_backwards_push(name):
        """Identity in forward; pushes ``name`` onto ``parents`` when its backward runs."""

        class PushState(torch.autograd.Function):

            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(name):
        """Identity in forward; pops ``name`` from ``parents`` when its backward runs."""

        class PopState(torch.autograd.Function):

            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert (parents[-1] == name)
                parents.pop()
                return grad_outs

        return PopState.apply

    def enter_module(name):
        """Forward pre-hook: attribute subsequent forward ops to ``name``."""

        def f(module, inputs):
            parents.append(name)
            inputs = normalize_tuple(inputs)
            # Backward runs in reverse, so the pop marker goes on the inputs.
            out = create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(name):
        """Forward hook: stop attributing to ``name``; mark outputs so backward re-enters it."""

        def f(module, inputs, outputs):
            assert (parents[-1] == name)
            parents.pop()
            outputs = normalize_tuple(outputs)
            return create_backwards_push(name)(*outputs)

        return f

    @contextmanager
    def instrument_module(mod):
        """Register enter/exit hooks on each direct child of ``mod``; always remove them on exit."""
        registered = []
        try:
            for name, child in dict(mod.named_children()).items():
                registered.append(child.register_forward_pre_hook(enter_module(name)))
                registered.append(child.register_forward_hook(exit_module(name)))
            yield
        finally:
            # Remove the hooks even if the model raised, so the user's module is left clean.
            for handle in registered:
                handle.remove()

    def display_flops():
        for mod, counts in flop_counts.items():
            print("Module: ", mod)
            for k, v in counts.items():
                print('\t', k, _format_flops(v))
            print()

    def detach_variables(r):
        if isinstance(r, torch.Tensor):
            requires_grad = r.requires_grad
            r = r.detach()
            r.requires_grad = requires_grad
        return r

    def wrap(r):
        if isinstance(r, torch.Tensor):
            r = FlopTensor(detach_variables(r))
            if maybe_inplace:
                r = r + 0
        return r

    with instrument_module(module):
        # Wrap inputs while ``cur_phase`` is still None so setup ops are not counted.
        wrapped_args = tree_map(wrap, args)
        wrapped_kwargs = tree_map(wrap, kwargs)
        cur_phase = Phase.FWD
        rst = module(*wrapped_args, **wrapped_kwargs)
        rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad)
        cur_phase = Phase.BWD

        if rst and not forward_only:
            # Zero output gradients: only which ops execute matters, not their values.
            grad = [torch.zeros_like(t) for t in rst]
            torch.autograd.backward(rst, grad)

    if verbose:
        display_flops()

    if forward_only:
        return total_flop_count[Phase.FWD]
    return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD]


def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for matmul (aten::mm / aten::matmul).

    Only multiply-accumulates are counted. Every output element is a dot
    product of length ``K`` (the contracted dimension), so the count is
    ``outputs[0].numel() * K``. Deriving it from the output shape keeps the
    count correct for broadcast batch dimensions on either operand and for
    1-D operands. For the plain 2-D case it equals ``M * K * N``.

    Args:
        inputs: the two operands (objects exposing ``.shape``).
        outputs: op results; ``outputs[0]`` is the product.
    Returns:
        int: the number of flops.
    """
    input_shapes = [v.shape for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    lhs_shape, rhs_shape = input_shapes
    contracted = lhs_shape[-1]
    # A 1-D right operand is a column vector: its only dimension is contracted.
    rhs_contracted = rhs_shape[0] if len(rhs_shape) == 1 else rhs_shape[-2]
    assert contracted == rhs_contracted, input_shapes
    # Initial value 1 so a 0-d result (vector dot product) has numel 1.
    return reduce(operator.mul, outputs[0].shape, 1) * contracted


def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for aten::addmm, the fused bias + matmul behind nn.Linear.

    ``inputs`` is ``(bias, mat1, mat2, ...)``. Only ``mat1`` (``[batch, in_features]``)
    and ``mat2`` (``[in_features, out_features]``) contribute; the bias add is ignored.
    """
    shapes = [t.shape for t in inputs[1:3]]
    assert len(shapes[0]) == 2, shapes[0]
    assert len(shapes[1]) == 2, shapes[1]
    rows, inner = shapes[0]
    cols = shapes[1][1]
    return rows * inner * cols


def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for the aten::linear operator.

    ``inputs`` is ``(x, weight, bias)``. ``x`` has shape ``[..., in_features]`` and
    ``weight`` has shape ``[out_features, in_features]``; the bias is ignored.
    The result is ``x.numel() * out_features``.
    """
    shapes = [t.shape for t in inputs[:2]]
    x_shape, w_shape = shapes[0], shapes[1]
    assert x_shape[-1] == w_shape[-1]
    out_features = w_shape[0]
    return reduce(operator.mul, x_shape) * out_features


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
    """
    Count flops for the bmm operation.

    ``inputs`` holds exactly two tensors: ``[batch, rows, inner]`` and
    ``[batch, inner, cols]``. The result is ``batch * rows * inner * cols``.
    """
    assert len(inputs) == 2, len(inputs)
    lhs_shape, rhs_shape = inputs[0].shape, inputs[1].shape
    batch, rows, inner = lhs_shape
    return batch * rows * inner * rhs_shape[-1]


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """
    Count flops for convolution. Only multiplications are counted; the
    additions and the bias are ignored.

    The count is ``batch_size * prod(w_shape) * prod(spatial)``. ``spatial`` is
    the output's spatial dims for a regular convolution and the input's spatial
    dims for a transposed one.

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch = x_shape[0]
    spatial = x_shape[2:] if transposed else out_shape[2:]
    kernel_volume = reduce(operator.mul, w_shape)
    spatial_volume = reduce(operator.mul, spatial)
    return batch * kernel_volume * spatial_volume


def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for aten::convolution / aten::_convolution.

    Uses the input and weight (``inputs[0]``, ``inputs[1]``), the result
    ``outputs[0]``, and the ``transposed`` flag at ``inputs[6]``.
    """
    x, w, *_ = inputs
    shapes = (x.shape, w.shape, outputs[0].shape)
    is_transposed = inputs[6]
    return conv_flop_count(*shapes, transposed=is_transposed)


def transpose_shape(shape):
    """Return ``shape`` as a list with its first two dimensions swapped."""
    first, second = shape[0], shape[1]
    return [second, first, *shape[2:]]


def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for aten::convolution_backward.

    ``inputs`` follows the ATen schema ``(grad_output, input, weight, bias_sizes,
    stride, padding, dilation, transposed, output_padding, groups, output_mask)``.
    ``outputs`` is ``(grad_input, grad_weight, grad_bias)``. Each requested
    gradient is expressed as a convolution and costed with ``conv_flop_count``,
    so it matches the forward cost. Bias gradients are not counted.
    """
    grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    groups = inputs[9]
    flop_count = 0

    if output_mask[0]:
        # grad_input convolves grad_output with the weight in the opposite
        # direction (transposed <-> regular) of the forward pass.
        grad_input_shape = outputs[0].shape
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = outputs[1].shape
        if fwd_transposed:
            # grad_weight = conv(grad_out^T, x^T)^T as a regular convolution.
            weight_flops = conv_flop_count(transpose_shape(grad_out_shape), transpose_shape(x_shape),
                                           transpose_shape(grad_weight_shape), transposed=False)
        else:
            # grad_weight = conv(x^T, grad_out^T)^T as a regular convolution.
            weight_flops = conv_flop_count(transpose_shape(x_shape), transpose_shape(grad_out_shape),
                                           transpose_shape(grad_weight_shape), transposed=False)
        # Using the channel axes as batch/filter dims ignores grouping, which
        # multiplies the count by exactly ``groups``.
        flop_count += weight_flops // groups

    return flop_count


def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
    """
    Build a flop counter for normalization ops.

    Args:
        affine_arg_index: index in ``inputs`` of the affine (weight) argument.
            Anything with a ``shape`` counts as affine. Any other value counts as
            affine when it is truthy (so ``None`` means "no affine").
        input_arg_index: index in ``inputs`` of the tensor whose size sets the cost.
    Returns:
        Callable: ``counter(inputs, outputs) -> int``, returning ``numel * 5``
        with affine parameters and ``numel * 4`` without.
    """

    def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> int:
        """Count flops for norm layers."""
        shape = inputs[input_arg_index].shape
        affine = inputs[affine_arg_index]
        if hasattr(affine, 'shape'):
            has_affine = affine.shape is not None
        else:
            has_affine = affine
        assert 2 <= len(shape) <= 5, shape
        # The per-element multipliers are rough estimates.
        per_element = 5 if has_affine else 4
        return reduce(operator.mul, shape) * per_element

    return norm_flop_jit


def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: Union[bool, None] = None) -> int:
    """
    Count flops for batch norm.

    Args:
        inputs: op arguments. ``inputs[0]`` is the input tensor and ``inputs[1]``
            the weight, which may be ``None``. When ``training`` is None, the
            training flag is read from ``inputs[-3]``.
        outputs: op results; only forwarded to the training-mode counter.
        training: optional override for the training flag.
    Returns:
        int: estimated flops. Training mode uses ``norm_flop_counter(1, 0)``.
        Eval mode costs one op per element, or two when a weight is present.
    """
    if training is None:
        training = inputs[-3]
    assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
    if training:
        return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore
    # The weight is None for non-affine batch norm, so test for None rather than
    # reading ``.shape`` from it.
    has_affine = inputs[1] is not None
    input_numel = reduce(operator.mul, inputs[0].shape)
    return input_numel * (2 if has_affine else 1)


def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
    """
    Count flops by
        input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
    Args:
        input_scale: scale of the input tensor (first argument)
        output_scale: scale of the output tensor (first element in outputs)
    Returns:
        Callable: ``counter(inputs, outputs)``. A tensor is only accessed when its scale is non-zero.
    """

    def _numel(shape) -> int:
        # Start the product at 1: a 0-d tensor has shape () but numel() == 1.
        return reduce(operator.mul, shape, 1)

    def ewise_flop(inputs: List[Any], outputs: List[Any]) -> int:
        ret = 0
        if input_scale != 0:
            ret += input_scale * _numel(inputs[0].shape)
        if output_scale != 0:
            ret += output_scale * _numel(outputs[0].shape)
        return ret

    return ewise_flop


def zero_flop_jit(*args):
    """Flop counter for operators treated as free: ignores every argument and returns 0."""
    del args  # accepted only for signature compatibility with the other counters
    return 0


# pyright: reportGeneralTypeIssues=false
# Maps each ATen op overload to ``counter(inputs, outputs) -> flops``. ``inputs`` are
# the op's positional arguments and ``outputs`` its results, as a tuple.
# Ops not listed here are counted as zero flops.
flop_mapping = {
    # gemm
    aten.mm.default: matmul_flop_jit,
    aten.matmul.default: matmul_flop_jit,
    aten.addmm.default: addmm_flop_jit,
    aten.bmm.default: bmm_flop_jit,

    # convolution
    aten.convolution.default: conv_flop_jit,
    aten._convolution.default: conv_flop_jit,
    aten.convolution_backward.default: conv_backward_flop_jit,

    # normalization
    aten.native_batch_norm.default: batchnorm_flop_jit,
    aten.native_batch_norm_backward.default: batchnorm_flop_jit,
    aten.cudnn_batch_norm.default: batchnorm_flop_jit,
    aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
    aten.native_layer_norm.default: norm_flop_counter(2, 0),
    # native_layer_norm_backward(grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask):
    # the optional weight is argument 5. Argument 2 is normalized_shape, which is
    # always truthy and so would always count the op as affine.
    aten.native_layer_norm_backward.default: norm_flop_counter(5, 0),

    # pooling
    aten.avg_pool1d.default: ewise_flop_counter(1, 0),
    aten.avg_pool2d.default: ewise_flop_counter(1, 0),
    aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
    aten.avg_pool3d.default: ewise_flop_counter(1, 0),
    aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
    aten.max_pool1d.default: ewise_flop_counter(1, 0),
    aten.max_pool2d.default: ewise_flop_counter(1, 0),
    aten.max_pool3d.default: ewise_flop_counter(1, 0),
    aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
    aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
    aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
    aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
    aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
    aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
    aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
    aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
    aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
    aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
    aten.embedding.default: ewise_flop_counter(1, 0),
}

# Element-wise ops; each is costed at one flop per element of its first input.
ewise_flop_aten = [
    # basic op
    aten.add.Tensor,
    aten.add_.Tensor,
    aten.sub.Tensor,
    aten.sub_.Tensor,
    aten.div.Tensor,
    aten.div_.Tensor,
    aten.div.Scalar,
    aten.div_.Scalar,
    aten.mul.Tensor,
    aten.mul.Scalar,
    aten.mul_.Tensor,
    aten.neg.default,
    aten.pow.Tensor_Scalar,
    aten.rsub.Scalar,
    aten.sum.default,
    aten.sum.dim_IntList,
    aten.mean.dim,

    # activation op
    aten.hardswish.default,
    aten.hardswish_.default,
    aten.hardswish_backward.default,
    aten.hardtanh.default,
    aten.hardtanh_.default,
    aten.hardtanh_backward.default,
    aten.hardsigmoid_backward.default,
    aten.hardsigmoid.default,
    aten.gelu.default,
    aten.gelu_backward.default,
    aten.silu.default,
    aten.silu_.default,
    aten.silu_backward.default,
    aten.sigmoid.default,
    aten.sigmoid_backward.default,
    aten._softmax.default,
    aten._softmax_backward_data.default,
    aten.relu_.default,
    aten.relu.default,
    aten.tanh.default,
    aten.tanh_backward.default,
    aten.threshold_backward.default,

    # dropout
    aten.native_dropout.default,
    aten.native_dropout_backward.default,

    # distribution
    aten.bernoulli_.float,

    # where
    aten.where.self,
]
for op in ewise_flop_aten:
    flop_mapping[op] = ewise_flop_counter(1, 0)
